import os

from django.contrib import auth
from django.contrib.auth.decorators import login_required, user_passes_test
from django.contrib.auth.models import User, Group
from django.core.paginator import Paginator, PageNotAnInteger, EmptyPage
from django.shortcuts import render
from django.http import HttpResponseRedirect
from django.urls import reverse

from management.models import Fund, Stock, Rank_stock

# coding=utf-8
import pandas as pd
import numpy as np
import math

from sklearn.linear_model import LinearRegression
from sklearn.model_selection import train_test_split
from bokeh.embed import components
from bokeh.plotting import figure


def index(request):
    """Render the homepage listing every ``Rank_stock`` record.

    The template receives ``request``, ``user`` (the logged-in user, or
    ``None`` for anonymous visitors) and ``objs`` (all ``Rank_stock`` rows).
    """
    current_user = None
    if request.user.is_authenticated:
        current_user = request.user
    ranked_stocks = Rank_stock.objects.all()
    template_context = {
        'request': request,
        'user': current_user,
        'objs': ranked_stocks,
    }
    return render(request, 'management/index.html', template_context)


def sign_up(request):
    """Register a new account from the sign-up form.

    Authenticated users are redirected to the homepage. On POST the outcome
    is reported to the template through ``state``:

    * ``'empty'`` - username, password or repeated password was blank;
    * ``'repeat_error'`` - the two passwords differ;
    * ``'user_exist'`` - the username is already taken;
    * ``'success'`` - the account was created.
    """
    if request.user.is_authenticated:
        return HttpResponseRedirect(reverse('homepage'))
    state = None
    if request.method == 'POST':
        username = request.POST.get('username', '')
        password = request.POST.get('password', '')
        repeat_password = request.POST.get('repeat_password', '')
        if not username or not password or not repeat_password:
            # create_user() rejects an empty username with ValueError, so
            # report it as a form error instead of failing the request.
            state = 'empty'
        elif password != repeat_password:
            state = 'repeat_error'
        elif User.objects.filter(username=username).exists():
            state = 'user_exist'
        else:
            # create_user() hashes the password and persists the row itself.
            User.objects.create_user(username=username, password=password,
                                     email=request.POST.get('email', ''))
            state = 'success'
    context = {
        'active_menu': 'homepage',
        'state': state,
        'user': None
    }
    return render(request, 'management/sign_up.html', context)


def login(request):
    """Authenticate the submitted credentials and start a session.

    Authenticated users are redirected to the homepage. On a successful
    POST the user is redirected to the ``next`` query parameter, but only
    when it points at this host; otherwise to the homepage. On failure the
    form is re-rendered with ``state='not_exist_or_password_error'``.
    """
    try:
        from django.utils.http import url_has_allowed_host_and_scheme as is_safe_target
    except ImportError:  # Django < 3.0 ships the same check under this name
        from django.utils.http import is_safe_url as is_safe_target

    if request.user.is_authenticated:
        return HttpResponseRedirect(reverse('homepage'))
    state = None
    if request.method == 'POST':
        username = request.POST.get('username', '')
        password = request.POST.get('password', '')
        user = auth.authenticate(username=username, password=password)
        if user is None:
            state = 'not_exist_or_password_error'
        else:
            auth.login(request, user)
            # ``next`` is attacker-controlled: never redirect off-site with it.
            target_url = request.GET.get('next', '')
            if not is_safe_target(url=target_url,
                                  allowed_hosts={request.get_host()},
                                  require_https=request.is_secure()):
                target_url = reverse('homepage')
            return HttpResponseRedirect(target_url)
    context = {
        'active_menu': 'homepage',
        'state': state,
        'user': None
    }
    return render(request, 'management/login.html', context)


def logout(request):
    """End the current session (if any) and send the visitor to the homepage."""
    auth.logout(request)
    homepage_url = reverse('homepage')
    return HttpResponseRedirect(homepage_url)


@login_required
def change_password(request):
    """Let the logged-in user change their password.

    On POST the outcome is reported to the template through ``state``:

    * ``'password_error'`` - the old password is wrong;
    * ``'empty'`` - the new password is blank;
    * ``'repeat_error'`` - the new password and its repetition differ;
    * ``'success'`` - the password was changed and the session kept alive.
    """
    user = request.user
    state = None
    if request.method == 'POST':
        old_password = request.POST.get('old_password', '')
        new_password = request.POST.get('new_password', '')
        repeat_password = request.POST.get('repeat_password', '')
        if not user.check_password(old_password):
            state = 'password_error'
        elif not new_password:
            state = 'empty'
        elif new_password != repeat_password:
            state = 'repeat_error'
        else:
            user.set_password(new_password)
            user.save()
            # Django ties the session to a hash of the password; refresh it
            # so the user is not logged out on their next request.
            auth.update_session_auth_hash(request, user)
            state = 'success'
    context = {
        'user': user,
        'active_menu': 'homepage',
        'state': state,
    }
    return render(request, 'management/change_password.html', context)


@login_required
def fund_list(request, buy_fee_rate='all'):
    """Paginated fund listing, optionally filtered by purchase fee rate.

    If no fund matches ``buy_fee_rate`` (including the default ``'all'``),
    every fund is shown and the query rate reported to the template becomes
    ``'all'``. Five funds per page; a non-integer ``page`` shows page 1 and
    an out-of-range one shows the last page.
    """
    fee_rate_list = Fund.objects.values_list('buy_fee_rate', flat=True).distinct()
    matching_funds = Fund.objects.filter(buy_fee_rate=buy_fee_rate)
    if matching_funds.exists():
        selected_funds = matching_funds
    else:
        buy_fee_rate = 'all'
        selected_funds = Fund.objects.all()

    pager = Paginator(selected_funds, 5)
    current_page = pager.get_page(request.GET.get('page'))

    return render(request, 'management/fund_list.html', {
        'user': request.user,
        'active_menu': 'view_fund',
        'fee_rate_list': fee_rate_list,
        'query_fee_rate': buy_fee_rate,
        'fund_list': current_page,
    })


@login_required
def fund_detail(request, fund_id=1):
    """Show a single fund; an unknown ``fund_id`` redirects to the full fund list."""
    try:
        selected_fund = Fund.objects.get(pk=fund_id)
    except Fund.DoesNotExist:
        list_url = reverse('fund_list', args=('all',))
        return HttpResponseRedirect(list_url)
    return render(request, 'management/fund_detail.html', {
        'user': request.user,
        'active_menu': 'view_fund',
        'fund': selected_fund,
    })


@login_required
def add_fund(request):
    """Upload a CSV file, store it and render its regression analysis.

    A GET, or a POST without a ``csv_file`` upload, renders the upload form.
    Otherwise the file is written to ``media/upload/<name>`` (relative to the
    process working directory, created if missing), ``data_analyse`` is run
    on it, and the returned Bokeh ``script``/``div`` are rendered together
    with the first six characters of the file name.
    """
    upload = request.FILES.get('csv_file') if request.method == "POST" else None
    if upload is None:
        return render(request, 'management/add_fund.html')

    upload_dir = os.path.join('media', 'upload')
    os.makedirs(upload_dir, exist_ok=True)
    # Keep only the final path component so the name cannot escape upload_dir.
    file_name = os.path.basename(upload.name)
    file_path = os.path.join(upload_dir, file_name)
    with open(file_path, 'wb') as fp:
        for chunk in upload.chunks():
            fp.write(chunk)

    # Embed the Bokeh plots in the page.
    script, div = data_analyse(file_name)
    context = {
        'script': script,
        'div': div,
        'name': file_name[0:6]
    }
    return render(request, 'management/data_analyse.html', context)


def _fit_comparison_plot(plot_df):
    """Return a datetime figure overlaying real closes and ``predictedVal``.

    ``plot_df`` must have the columns ``日期`` (date), ``收盘价`` (close) and
    ``predictedVal``.
    """
    p = figure(title='线性回归预测股票/基金走势（拟合对比）',
               x_axis_label='日期(date)', y_axis_label='价格(price)',
               x_axis_type="datetime", width=800, height=400)
    p.line(plot_df['日期'], plot_df['收盘价'], legend_label="Real Data", line_color='blue')
    p.circle(plot_df['日期'], plot_df['predictedVal'], legend_label="Predicted Data",
             line_color='red')
    p.line(plot_df['日期'], plot_df['predictedVal'], legend_label="Predicted Data",
           line_color='red')
    p.legend.location = 'top_left'
    p.legend.border_line_width = 3
    p.legend.border_line_color = "navy"
    p.legend.border_line_alpha = 0.5
    return p


def _forecast_plot(values):
    """Return a scatter + line figure of ``values`` against positions 0..len-1."""
    positions = list(range(len(values)))
    p1 = figure(title='未来十天走势预测',
                x_axis_label='未来十天', y_axis_label='价格(price)',
                width=800, height=400)
    p1.scatter(positions, values, size=12,
               legend_label="Predicted Data", color="red", alpha=0.5)
    p1.line(positions, values, legend_label="Predicted Data", line_color='red')
    p1.legend.location = 'top_left'
    return p1


def data_analyse(file_name, test_size=0.1, upload_dir='media/upload'):
    """Fit a linear regression on an uploaded price CSV and build Bokeh plots.

    Reads at most the first 500 rows of ``<upload_dir>/<file_name>`` (GBK
    encoded). The columns ``日期`` (date), ``开盘价`` (open), ``最高价`` (high),
    ``最低价`` (low), ``成交量`` (volume) and ``收盘价`` (close) are required.
    Rows are sorted by date. A ``LinearRegression`` is fitted on open/high/
    volume/low -> close using a ``train_test_split`` with ``test_size``.

    Args:
        file_name: Name of the CSV file inside ``upload_dir``.
        test_size: Fraction of rows held out for testing. It must leave at
            least ``ceil(0.05 * rows)`` test rows.
        upload_dir: Directory holding the file. Defaults to the relative
            ``media/upload`` directory that uploads are written to.

    Returns:
        ``(script, div)`` from ``bokeh.embed.components`` for the plots keyed
        ``'1'`` (real vs. predicted closes) and ``'2'`` (up to ten further
        test-set predictions).

    Raises:
        ValueError: if the test split has fewer rows than the number of
            predicted days.
    """
    csv_path = os.path.join(upload_dir, file_name)
    orig_df = pd.read_csv(csv_path, encoding='gbk', nrows=500)
    # Sort ascending by date so the tail of the frame is the most recent data.
    orig_df['日期'] = pd.to_datetime(orig_df['日期'])
    orig_df = orig_df.sort_values(by='日期').reset_index(drop=True)

    feature = orig_df[['开盘价', '最高价', '成交量', '最低价']].values
    target = np.array(orig_df['收盘价'])
    feature_train, feature_test, target_train, _ = train_test_split(
        feature, target, test_size=test_size)
    predicted_days = int(math.ceil(0.05 * len(orig_df)))  # last 5% of rows

    lr_tool = LinearRegression()
    lr_tool.fit(feature_train, target_train)
    # NOTE(review): train_test_split shuffles by default, so these are
    # predictions for randomly chosen rows, not for the final trading days
    # they are plotted against below.
    predict_by_test = lr_tool.predict(feature_test)
    if len(predict_by_test) < predicted_days:
        raise ValueError(
            'test_size=%r leaves %d test rows but %d predicted days are needed'
            % (test_size, len(predict_by_test), predicted_days))

    # predictedVal equals the real close on the first 95% of rows and the
    # model's test-set output on the last 5%. Built on a copy so orig_df is
    # never mutated through a slice.
    split = len(orig_df) - predicted_days
    predicted_val = np.array(orig_df['收盘价'], dtype=float)
    predicted_val[split:] = predict_by_test[:predicted_days]
    plot_df = orig_df[['日期', '收盘价']].copy()
    plot_df['predictedVal'] = predicted_val

    # At most ten remaining predictions form the "next ten days" chart; the
    # x axis is sized from the data so both glyph columns have equal length.
    future = predict_by_test[predicted_days:predicted_days + 10]

    plots = {'1': _fit_comparison_plot(plot_df), '2': _forecast_plot(future)}
    script, div = components(plots)
    return script, div


@user_passes_test(lambda u: u.is_staff)
def user_list(request):
    """Staff-only, paginated (4 per page) user list with optional filters.

    Recognised GET parameters:

    * ``group`` - integer group id; non-integers are ignored;
    * ``is_staff`` / ``is_superuser`` - ``'0'`` or ``'1'``; other values ignored;
    * ``username`` - exact username match;
    * ``page`` - page number; non-integers fall back to page 1.

    The applied filters are also merged into the template context under the
    same keys, so the template can re-display them.
    """
    # Explicit ordering keeps page boundaries stable across requests.
    user_queryset = (User.objects
                     .only('username', 'is_active', 'email', 'is_staff', 'is_superuser')
                     .order_by('pk'))
    groups = Group.objects.only('name')
    query_dict = {}

    groups__id = request.GET.get('group')
    if groups__id:
        try:
            query_dict['groups__id'] = int(groups__id)
        except ValueError:
            pass  # an unparsable group id simply means "no group filter"

    flag_values = {'0': False, '1': True}
    for flag in ('is_staff', 'is_superuser'):
        raw_value = request.GET.get(flag)
        if raw_value in flag_values:
            query_dict[flag] = flag_values[raw_value]

    username = request.GET.get('username')
    if username:
        query_dict['username'] = username

    try:
        page = int(request.GET.get('page', 1))
    except ValueError:
        page = 1

    paginator = Paginator(user_queryset.filter(**query_dict), 4)
    users = paginator.get_page(page)
    context = {
        'users': users,
        'groups': groups
    }
    context.update(query_dict)
    return render(request, 'management/user_list.html', context=context)
